package mysql

import (
	"context"
	"fmt"
	_ "github.com/go-sql-driver/mysql"
	"time"
	"xorm.io/xorm"
)

// Xorm wraps an xorm EngineGroup (a master connection plus the optional
// slaves configured in NewXorm). Every EngineGroup/Engine method is promoted
// through the embedded pointer, so Xorm can be used wherever the group is.
type Xorm struct {
	*xorm.EngineGroup
}

// NewXorm builds an xorm EngineGroup for MySQL from cfg: cfg.Master becomes
// the master and every entry of cfg.Slaves is added as a slave. The
// connection-pool limits and SQL logging flag from cfg are applied to the
// whole group, and connectivity is verified with Ping.
//
// It panics if the group cannot be created or if the ping fails.
func NewXorm(cfg Config) *Xorm {
	conns := make([]string, 0, len(cfg.Slaves)+1)
	conns = append(conns, buildXormDSN(cfg.Master.Username, cfg.Master.Password,
		cfg.Master.HostPort, cfg.Master.Database))
	for _, slave := range cfg.Slaves {
		conns = append(conns, buildXormDSN(slave.Username, slave.Password,
			slave.HostPort, slave.Database))
	}

	engine, err := xorm.NewEngineGroup("mysql", conns)
	if err != nil {
		panic(err)
	}

	engine.SetMaxOpenConns(cfg.MaxOpenConn)
	engine.SetMaxIdleConns(cfg.MaxIdleConn)
	engine.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetime) * time.Second)
	engine.ShowSQL(cfg.ShowSQL)

	if err := engine.Ping(); err != nil {
		// Release the pools opened for master and slaves before bailing out,
		// so a caller that recovers from the panic does not leak them.
		_ = engine.Close()
		panic(err)
	}
	return &Xorm{EngineGroup: engine}
}

// buildXormDSN formats the go-sql-driver/mysql DSN used for every node of the
// group: utf8mb4 charset and the Asia/Shanghai location (URL-escaped).
//
// NOTE(review): the driver only uses loc for DATETIME parsing when
// parseTime=true, which this DSN does not set — confirm that is intended.
func buildXormDSN(username, password, hostPort, database string) string {
	const loc = "Asia%2fShanghai"
	return fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&loc=%v",
		username, password, hostPort, database, loc)
}

// TransactionWithSession runs tx inside a transaction on sess and supports
// nesting.
//
// If sess already has an open transaction (see isStartTx), tx is simply
// executed on it. The outer caller stays responsible for committing or
// rolling back, and for closing the session.
//
// Otherwise a transaction is begun on sess. It is committed when tx returns
// nil and rolled back when tx returns an error. In both cases sess is closed
// before returning.
//
// A nil sess is accepted, consistent with the other Tx* helpers. A new
// session is then created from the engine group for the duration of the call.
func (x *Xorm) TransactionWithSession(sess *xorm.Session, tx func(sess *xorm.Session) error) error {
	if sess == nil {
		sess = x.NewSession()
	} else if x.isStartTx(sess) {
		// Nested call: join the caller's transaction without finishing it.
		return tx(sess)
	}

	// NOTE(review): relies on xorm's Session.Close rolling back a transaction
	// that was neither committed nor rolled back (e.g. when tx panics) —
	// confirm against the xorm version in use.
	defer func() {
		_ = sess.Close()
	}()

	if err := sess.Begin(); err != nil {
		return err
	}

	if err := tx(sess); err != nil {
		// The error from tx is the one worth reporting; a rollback failure
		// is deliberately dropped.
		_ = sess.Rollback()
		return err
	}
	return sess.Commit()
}
// isStartTx reports whether sess already has an open transaction. In that case
// TransactionWithSession must not begin, commit or close it.
//
// The session is asked directly via IsInTx instead of inspecting LastSQL.
// The previous LastSQL check was wrong in two ways:
//   - The last SQL is overwritten by every statement executed inside the
//     transaction, so a live transaction was missed after its first query,
//     and the nested caller then committed and closed it.
//   - "COMMIT"/"ROLL BACK" mean the transaction has already ended, yet they
//     were reported as open.
func (x *Xorm) isStartTx(sess *xorm.Session) bool {
	return sess != nil && sess.IsInTx()
}

// Transaction runs tx inside a transaction on the master engine. It delegates
// to Engine.Transaction of the embedded master Engine, discarding the
// interface{} result slot. The returned error is the one reported by that
// call (nil on success).
func (x *Xorm) Transaction(tx func(sess *xorm.Session) error) error {
	wrapped := func(s *xorm.Session) (interface{}, error) {
		err := tx(s)
		return nil, err
	}
	if _, err := x.Engine.Transaction(wrapped); err != nil {
		return err
	}
	return nil
}
// MysqlTransaction is an alias of Transaction. It runs tx inside a
// transaction on the master engine and returns Transaction's error.
func (x *Xorm) MysqlTransaction(tx func(sess *xorm.Session) error) error {
	if err := x.Transaction(tx); err != nil {
		return err
	}
	return nil
}
// TxInsert inserts bean and requires exactly one affected row.
// It runs on sess when sess is non-nil (so it joins any transaction open on
// it), and otherwise on x.Context(context.Background()).
func (x *Xorm) TxInsert(sess *xorm.Session, bean interface{}) error {
	runner := sess
	if runner == nil {
		runner = x.Context(context.Background())
	}

	n, err := runner.Insert(bean)
	switch {
	case err != nil:
		return err
	case n != 1:
		return fmt.Errorf("wrong affected:%d", n)
	default:
		return nil
	}
}

// TxInsertIng inserts bean while leaving out the columns listed in omit
// (nothing is omitted when omit is empty), and requires exactly one affected
// row. It runs on sess when non-nil, otherwise on
// x.Context(context.Background()).
func (x *Xorm) TxInsertIng(sess *xorm.Session, bean interface{}, omit []string) error {
	s := sess
	if s == nil {
		s = x.Context(context.Background())
	}
	if len(omit) != 0 {
		s = s.Omit(omit...)
	}

	n, err := s.Insert(bean)
	if err != nil {
		return err
	}
	if n == 1 {
		return nil
	}
	return fmt.Errorf("wrong affected:%d", n)
}

// TxUpdate updates the row whose primary key is id with the values in bean.
// The update is restricted to cols when cols is non-empty. It runs on sess
// when non-nil, otherwise on x.Context(context.Background()). An error is
// returned if Update fails or does not affect exactly one row.
//
// NOTE(review): with go-sql-driver/mysql defaults (no clientFoundRows), MySQL
// reports 0 affected rows when the new values equal the stored ones. A no-op
// update therefore surfaces as "wrong affected:0" — confirm this is intended.
func (x *Xorm) TxUpdate(sess *xorm.Session, id int64, bean interface{}, cols []string) error {
	if sess == nil {
		sess = x.Context(context.Background())
	}
	sess = sess.ID(id)
	// len of a nil slice is 0, so no separate nil check is needed.
	if len(cols) > 0 {
		sess.Cols(cols...)
	}

	affected, err := sess.Update(bean)
	if err != nil {
		return err
	}
	if affected != 1 {
		return fmt.Errorf("wrong affected:%d", affected)
	}
	return nil
}

// TxUpdateMust updates the row whose primary key is id with the values in
// bean. The columns in mustCol are always written, even when bean holds their
// zero value (this is how empty values are persisted). It runs on sess when
// non-nil, otherwise on x.Context(context.Background()). An error is returned
// if Update fails or does not affect exactly one row.
func (x *Xorm) TxUpdateMust(sess *xorm.Session, id int64, bean interface{}, mustCol []string) error {
	if sess == nil {
		sess = x.Context(context.Background())
	}
	sess = sess.ID(id)
	// len of a nil slice is 0, so no separate nil check is needed.
	if len(mustCol) > 0 {
		sess.MustCols(mustCol...)
	}

	affected, err := sess.Update(bean)
	if err != nil {
		return err
	}
	if affected != 1 {
		return fmt.Errorf("wrong affected:%d", affected)
	}
	return nil
}

// TxUpdateByCondition updates the row whose primary key is id with the values
// in bean. The update is restricted to cols when non-empty, and additionally
// filtered by the raw SQL condition con when con is not "". It runs on sess
// when non-nil, otherwise on x.Context(context.Background()).
//
// Unlike TxUpdate, the affected row count is not checked. An update that
// matches no row (e.g. because con is false for it) is not an error.
//
// NOTE(review): con is inserted into the WHERE clause verbatim and must never
// contain untrusted input.
func (x *Xorm) TxUpdateByCondition(sess *xorm.Session, id int64, bean interface{}, cols []string, con string) error {
	if sess == nil {
		sess = x.Context(context.Background())
	}
	sess = sess.ID(id)
	// len of a nil slice is 0, so no separate nil check is needed.
	if len(cols) > 0 {
		sess.Cols(cols...)
	}
	if con != "" {
		sess.And(con)
	}

	if _, err := sess.Update(bean); err != nil {
		return err
	}
	return nil
}
// TxBeanUpdate updates the row whose primary key is id with the values in
// bean (xorm's default column selection) and requires exactly one affected
// row. It runs on sess when non-nil, otherwise on
// x.Context(context.Background()).
func (x *Xorm) TxBeanUpdate(sess *xorm.Session, id int64, bean interface{}) error {
	runner := sess
	if runner == nil {
		runner = x.Context(context.Background())
	}

	n, err := runner.ID(id).Update(bean)
	switch {
	case err != nil:
		return err
	case n != 1:
		return fmt.Errorf("wrong affected:%d", n)
	default:
		return nil
	}
}
// TxExecSql executes the raw statement sql with the placeholder arguments
// val. It runs on sess when non-nil, otherwise on
// x.Context(context.Background()). The error from Exec is returned unchanged,
// so callers can inspect it with errors.Is / errors.As.
func (x *Xorm) TxExecSql(sess *xorm.Session, sql string, val ...interface{}) error {
	if sess == nil {
		sess = x.Context(context.Background())
	}
	// Session.Exec takes the SQL and its arguments as one variadic list, so
	// both are flattened into a single slice. The old code passed val itself,
	// which sent the whole []interface{} as a single bind argument.
	args := make([]interface{}, 0, len(val)+1)
	args = append(args, sql)
	args = append(args, val...)
	if _, err := sess.Exec(args...); err != nil {
		return err
	}
	return nil
}
// MyCount counts the rows of bean's table that match every key = value pair
// in param, combined with AND. An empty or nil param returns (0, nil) without
// querying the database.
//
// The previous version called err.Error() unconditionally, which panicked
// with a nil dereference on every successful count.
//
// NOTE(review): keys are interpolated into the SQL text verbatim, so only
// trusted column names may be used as keys. Values are bound as arguments.
func (x *Xorm) MyCount(bean interface{}, param map[string]interface{}) (int64, error) {
	if len(param) == 0 {
		return 0, nil
	}

	sess := x.Context(context.Background())
	for key, val := range param {
		sess.And(fmt.Sprintf("%v = ?", key), val)
	}

	row, err := sess.Count(bean)
	if err != nil {
		return 0, err
	}
	return row, nil
}
// TxBatchDelete deletes the rows of bean's table whose "id" column is in ids.
// The affected row count is not checked. It runs on sess when non-nil,
// otherwise on x.Context(context.Background()).
func (x *Xorm) TxBatchDelete(sess *xorm.Session, ids []int64, bean interface{}) error {
	s := sess
	if s == nil {
		s = x.Context(context.Background())
	}
	if _, err := s.In("id", ids).Delete(bean); err != nil {
		return err
	}
	return nil
}
// TxDelete deletes the row whose primary key is id from bean's table and
// requires exactly one affected row. It runs on sess when non-nil, otherwise
// on x.Context(context.Background()).
func (x *Xorm) TxDelete(sess *xorm.Session, id int64, bean interface{}) error {
	runner := sess
	if runner == nil {
		runner = x.Context(context.Background())
	}

	n, err := runner.ID(id).Delete(bean)
	if err != nil {
		return err
	}
	if n == 1 {
		return nil
	}
	return fmt.Errorf("wrong affected:%d", n)
}
// TxClear deletes every row of bean's table. It runs on sess when non-nil,
// otherwise on x.Context(context.Background()).
//
// The always-true "1=1" condition is presumably there because xorm refuses a
// Delete without any condition — confirm against the xorm version in use.
func (x *Xorm) TxClear(sess *xorm.Session, bean interface{}) error {
	runner := sess
	if runner == nil {
		runner = x.Context(context.Background())
	}
	_, err := runner.Where("1=1").Delete(bean)
	return err
}
